Add SDPA backend tests and refactor generate.py #1477
base: main
Conversation
Push backend manager into caller
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1477
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit aee35a5 with merge base 083fdaf.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Add tests for backends
print out parameters during execution
The new attention-backend option has an interesting wrinkle: it only specifies one backend, and if that kernel does not work for the given parameters, things fail. At a minimum, should we always specify "math" as a backup? (See the specific error below.) Also, should we have a backend option "auto" that simply relies on the SDPA dispatch logic to find the best kernel?
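For illustration, here is a minimal sketch of what "one backend plus a math backup, or auto" could look like. This is not torchchat's actual code; `attention_context`, the option strings, and the `allow_math_fallback` flag are made up for this example, and only the `torch.nn.attention.sdpa_kernel` / `SDPBackend` API is real:

```python
import contextlib
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical mapping from a CLI string to an SDPA backend.
_NAME_TO_BACKEND = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
}

def attention_context(name: str, allow_math_fallback: bool = True):
    """Context manager restricting SDPA dispatch per the requested option."""
    if name == "auto":
        # No restriction: let SDPA's own dispatch pick the best kernel.
        return contextlib.nullcontext()
    backends = [_NAME_TO_BACKEND[name]]
    if allow_math_fallback and SDPBackend.MATH not in backends:
        # MATH accepts any shape/mask/dtype, so it acts as a safety net.
        backends.append(SDPBackend.MATH)
    return sdpa_kernel(backends)
```

Usage would be along the lines of `with attention_context(args.attention_backend): out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)`.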
Allow math as fallback
In a perfect world, we would have a support matrix of configurations we are confident in (with backing tests), and we would emit a warning right at the start when arguments wander into territory we are not guaranteeing. cc: @yanbing-j
I'm cautious about torchchat's use of fallbacks (and about fallbacks in other PyTorch projects in general, but that's a different can of worms), since it opens the door for misinterpretation: users thinking we are doing what they asked, when in reality it's succeeding with a different config.
Thoughts?
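A rough sketch of that support-matrix-plus-warning idea, assuming the PyTorch `SDPBackend` enum. The table contents below are placeholders, not claims about which kernels are actually tested:

```python
import warnings
from torch.nn.attention import SDPBackend

# Placeholder support matrix: (device_type, backend) pairs with backing tests.
TESTED_CONFIGS = {
    ("cpu", SDPBackend.MATH),
    ("cpu", SDPBackend.FLASH_ATTENTION),
    ("cuda", SDPBackend.MATH),
    ("cuda", SDPBackend.EFFICIENT_ATTENTION),
}

def warn_if_untested(device_type: str, backend: SDPBackend) -> None:
    """Warn at startup when the requested config is outside the matrix."""
    if (device_type, backend) not in TESTED_CONFIGS:
        warnings.warn(
            f"Attention backend {backend} on {device_type} is outside the "
            "tested support matrix; it may fail or behave unexpectedly."
        )
```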
This sounds amazing. Having an explicit list with justifications for the rankings also lets us catch when numerics look fishy.
Confirmed that on the CPU side, only math and flash_attention can be chosen as the attention backend. And so far, flash_attention does not run into the error mentioned above.
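To double-check something like this locally, one could probe each backend with a tiny CPU input and see which ones dispatch. This is an illustrative snippet with arbitrary shapes, not part of the PR:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim) -- arbitrary small CPU tensors.
q = k = v = torch.randn(1, 4, 8, 16)

for backend in (SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    try:
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v)
        print(f"{backend}: ok")
    except RuntimeError as err:
        print(f"{backend}: unavailable ({err})")
```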
I'm not so enamored of "fail early". "Fail early" turns into "no configuration ever works unless it's perfect", and with overly fragile systems that quickly becomes the empty set.

I agree with the "misleading" configuration issue: a user might feel we are doing what they asked when in reality we are not. OTOH, if we don't posit that things will just work, we push that call onto the user.

The moral equivalent in my mind is a compiler that will only compile programs it has been tested on, because everything else can't be expected to work. That's not useful for the end user, and it transfers a lot of responsibility to them -- responsibility they don't have enough data to make a meaningful call with. Imagine if every compiler returned "Unknown program. Compiled code not guaranteed to be correct." and thereby transferred the responsibility to the compiler user.
Tests from pytorch#1477 only, without the generate.py refactor
All good points!! Convinced that refusing to run is a bad experience.

For the sake of avoiding indirection, I suggest we execute the config as requested, allowing any errors or bad perf to surface to the caller, potentially failing at execution time. See the sketch below.

For the sake of testing, we can omit or check for the failure cases.
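In code, the "execute as requested, no silent fallback" version would simply pass the single requested backend through and let any RuntimeError propagate. Again a sketch with illustrative names, not the PR's implementation:

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel

def attention_step(q, k, v, mask, backend):
    # Restrict SDPA to exactly the requested backend -- no MATH appended.
    with sdpa_kernel(backend):
        # If this kernel cannot handle the mask/shape/dtype, the error
        # surfaces to the caller at execution time, as discussed above.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```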
#1480 runs a matrix of all tests against the attention backend options. In a nutshell, only MATH is guaranteed to handle all inputs, so if we exclude MATH we'll naturally get scenarios where we can't run correctly. TBH, given the call site is passing a mask that the flash attention kernel does not support, I don't think we can ever get it to use flash attention.
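The gist of such a matrix in pytest form (illustrative only, not the actual #1480 test file): MATH is expected to handle every input, while kernel-specific backends may legitimately refuse an explicit mask.

```python
import pytest
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

BACKENDS = [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]

@pytest.mark.parametrize("backend", BACKENDS)
def test_sdpa_with_explicit_mask(backend):
    q = k = v = torch.randn(1, 4, 8, 16)
    # An explicit boolean mask (rather than is_causal=True) rules out flash attention.
    mask = torch.ones(8, 8, dtype=torch.bool).tril()
    try:
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError:
        if backend == SDPBackend.MATH:
            raise  # MATH should never fail
        pytest.skip(f"{backend} does not support this input on this device")
    assert out.shape == q.shape
```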